#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0#
"""Various classes for spatially correlated flat-fading channels"""

from abc import abstractmethod
import tensorflow as tf
from tensorflow.experimental.numpy import swapaxes
from sionna.phy.block import Object
from sionna.phy.utils import expand_to_rank

class SpatialCorrelation(Object):
    # pylint: disable=line-too-long
    r"""Abstract interface implemented by all spatial correlation models

    A spatial correlation model is a callable that maps spatially
    uncorrelated channel coefficients to spatially correlated ones.
    The :class:`~sionna.phy.channel.FlatFadingChannel` model can be
    configured with such a model.

    Parameters
    ----------
    precision : `None` (default) | "single" | "double"
        Precision used for internal calculations and outputs.
        If set to `None`,
        :attr:`~sionna.phy.config.Config.precision` is used.

    Input
    -----
    h : `tf.complex`
        Tensor of arbitrary shape containing spatially uncorrelated
        channel coefficients

    Output
    ------
    h_corr : `tf.complex`
        Tensor of the same shape as ``h`` containing the spatially
        correlated channel coefficients
    """
    @abstractmethod
    def __call__(self, h, *args, **kwargs):
        r"""Apply the spatial correlation to ``h``

        Concrete models override this method. The base implementation
        performs no computation and simply returns ``NotImplemented``.
        """
        return NotImplemented

class KroneckerModel(SpatialCorrelation):
    # pylint: disable=line-too-long
    r"""Kronecker model for spatial correlation

    Given a batch of matrices :math:`\mathbf{H}\in\mathbb{C}^{M\times K}`,
    :math:`\mathbf{R}_\text{tx}\in\mathbb{C}^{K\times K}`, and
    :math:`\mathbf{R}_\text{rx}\in\mathbb{C}^{M\times M}`, this function
    will generate the following output:

    .. math::

        \mathbf{H}_\text{corr} = \mathbf{R}^{\frac12}_\text{rx} \mathbf{H} \mathbf{R}^{\frac12}_\text{tx}

    Note that :math:`\mathbf{R}_\text{tx}\in\mathbb{C}^{K\times K}` and :math:`\mathbf{R}_\text{rx}\in\mathbb{C}^{M\times M}`
    must be positive semi-definite, such as the ones generated by
    :meth:`~sionna.phy.channel.exp_corr_mat`.

    The square-root factors are obtained with :func:`tf.linalg.cholesky`,
    i.e., with :math:`\mathbf{R}=\mathbf{L}\mathbf{L}^{\mathsf{H}}`, the
    implementation computes
    :math:`\mathbf{L}_\text{rx}\mathbf{H}\mathbf{L}^{\mathsf{H}}_\text{tx}`.
    The TensorFlow documentation specifies Hermitian positive-definite
    inputs for this factorization; behavior for singular correlation
    matrices should be verified before relying on it.

    Parameters
    ----------
    r_tx : [..., K, K], `tf.complex`
        Transmit correlation matrices. If
        the rank of ``r_tx`` is smaller than that of the input ``h``,
        it will be broadcast. If `None` (default), no transmit
        correlation is applied.

    r_rx : [..., M, M], `tf.complex`
        Receive correlation matrices. If
        the rank of ``r_rx`` is smaller than that of the input ``h``,
        it will be broadcast. If `None` (default), no receive
        correlation is applied.

    precision : `None` (default) | "single" | "double"
        Precision used for internal calculations and outputs.
        If set to `None`,
        :attr:`~sionna.phy.config.Config.precision` is used.

    Input
    -----
    h : [..., M, K], `tf.complex`
        Spatially uncorrelated channel coefficients

    Output
    ------
    h_corr : [..., M, K], `tf.complex`
        Spatially correlated channel coefficients
    """
    def __init__(self, r_tx=None, r_rx=None, precision=None):
        # Forward the requested precision to the base class so that it is
        # not silently replaced by the global default.
        super().__init__(precision=precision)
        self.r_tx = r_tx
        self.r_rx = r_rx

    @property
    def r_tx(self):
        r"""
        [..., K, K], `tf.complex` : Get/set transmit correlation matrices
        """
        return self._r_tx

    @r_tx.setter
    def r_tx(self, value):
        self._r_tx = value

    @property
    def r_rx(self):
        r"""
        [..., M, M], `tf.complex` : Get/set receive correlation matrices
        """
        return self._r_rx

    @r_rx.setter
    def r_rx(self, value):
        self._r_rx = value

    def __call__(self, h):
        r"""Apply transmit and/or receive correlation to ``h``

        Each side is skipped when its correlation matrix is `None`, so
        ``h`` is returned unchanged if neither matrix has been set.
        The Cholesky factors are recomputed on every call from the
        current values of ``r_tx`` and ``r_rx``.
        """
        if self.r_tx is not None:
            l_tx = tf.linalg.cholesky(self.r_tx)
            # Right-multiply by L_tx^H to correlate the columns (transmit
            # side); batch dimensions are handled by tf.matmul.
            h = tf.matmul(h, l_tx, adjoint_b=True)

        if self.r_rx is not None:
            l_rx = tf.linalg.cholesky(self.r_rx)
            # Left-multiply by L_rx to correlate the rows (receive side).
            h = tf.matmul(l_rx, h)

        return h

class PerColumnModel(SpatialCorrelation):
    # pylint: disable=line-too-long
    r"""Per-column model for spatial correlation

    Given a batch of matrices :math:`\mathbf{H}\in\mathbb{C}^{M\times K}`
    and correlation matrices :math:`\mathbf{R}_k\in\mathbb{C}^{M\times M}, k=1,\dots,K`,
    this function will generate the output :math:`\mathbf{H}_\text{corr}\in\mathbb{C}^{M\times K}`,
    with columns

    .. math::

        \mathbf{h}^\text{corr}_k = \mathbf{R}^{\frac12}_k \mathbf{h}_k,\quad k=1, \dots, K

    where :math:`\mathbf{h}_k` is the kth column of :math:`\mathbf{H}`.
    Note that all :math:`\mathbf{R}_k\in\mathbb{C}^{M\times M}` must
    be positive semi-definite, such as the ones generated
    by :meth:`~sionna.phy.channel.one_ring_corr_mat`.

    This model is typically used to simulate a MIMO channel between multiple
    single-antenna users and a base station with multiple antennas.
    The resulting SIMO channel for each user has a different spatial correlation.

    Parameters
    ----------
    r_rx : [..., M, M], `tf.complex`
        Receive correlation matrices. If
        the rank of ``r_rx`` is smaller than that of the input ``h``,
        it will be broadcast. In a typical use of this model, ``r_rx``
        has shape [..., K, M, M], i.e., a different correlation matrix for each
        column of ``h``.

    precision : `None` (default) | "single" | "double"
        Precision used for internal calculations and outputs.
        If set to `None`,
        :attr:`~sionna.phy.config.Config.precision` is used.

    Input
    -----
    h : [..., M, K], `tf.complex`
        Spatially uncorrelated channel coefficients

    Output
    ------
    h_corr : [..., M, K], `tf.complex`
        Spatially correlated channel coefficients
    """
    def __init__(self, r_rx, precision=None):
        super().__init__(precision=precision)
        self.r_rx = r_rx

    @property
    def r_rx(self):
        r"""
        [..., M, M], `tf.complex` : Get/set receive correlation matrices
        """
        return self._r_rx

    @r_rx.setter
    def r_rx(self, value):
        self._r_rx = value

    def __call__(self, h):
        r"""Correlate every column of ``h`` with its own matrix

        ``h`` is returned unchanged when ``r_rx`` is `None`. Otherwise
        the Cholesky factor of ``r_rx`` is applied to each column.
        """
        corr = self.r_rx
        if corr is None:
            return h

        factor = tf.linalg.cholesky(corr)

        # Turn each column into a batched column vector:
        # [..., M, K] -> [..., K, M] -> [..., K, M, 1]
        cols = tf.expand_dims(swapaxes(h, -2, -1), axis=-1)

        # Prepend dimensions to the factor until it has the rank of `cols`
        factor = expand_to_rank(factor, tf.rank(cols), 0)

        # Matrix-vector products, then undo the reshaping:
        # [..., K, M, 1] -> [..., K, M] -> [..., M, K]
        cols = tf.squeeze(tf.matmul(factor, cols), axis=-1)
        return swapaxes(cols, -2, -1)
